#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

/*
    Fused softmax -> top-k -> get_rows (-> optional normalization) pipeline for MoE routing.

    Expected launch layout: blockDim.x == WARP_SIZE and each value of threadIdx.y handles one row,
    so every row is processed by exactly one warp and all synchronization is warp-level.

    Per row of n_experts logits the kernel:
    1. computes the softmax over the n_experts logits,
    2. selects the n_expert_used largest softmax weights by repeated warp-wide argmax
       (equal weights are resolved towards the lower expert index),
    3. writes the selected weights to `weights` (row stride n_expert_used) and the selected
       expert indices to `ids` (row stride n_experts, first n_expert_used entries written),
       in non-increasing weight order,
    4. if `normalize` is set, rescales the selected weights by 1 / (sum of selected weights).

    Precondition: n_expert_used <= n_experts.
*/
template <size_t n_experts, bool normalize>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
                                                                  const int     n_rows,
                                                                  const int     n_expert_used) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    // All lanes of a warp share `row`, so this exit is uniform within the warp.
    if (row >= n_rows) {
        return;
    }

    logits += n_experts * row;
    weights += n_expert_used * row;
    ids += n_experts * row;

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    // Lane l holds experts l, l + WARP_SIZE, ...; lanes past n_experts (only possible when
    // n_experts < WARP_SIZE) hold -inf so they get a softmax weight of 0.
    float logits_r[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert        = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
    }

    // Numerically stable softmax: subtract the row maximum before exponentiating.
    float row_max = logits_r[0];

#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        row_max         = max(val, row_max);
    }

    row_max = warp_reduce_max(row_max);

    float wt[experts_per_thread];
    float denom = 0.0f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i]           = expf(val - row_max);
        denom += wt[i];
    }

    denom = warp_reduce_sum(denom);

    const float inv_denom = 1.0f / denom;
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_denom;
    }

    // At this point each lane holds its slice of the softmax. Select n_expert_used experts by
    // repeated warp-wide argmax; the lane owning the winner marks it as -inf so it is excluded
    // from later iterations.
    [[maybe_unused]] float sum_selected = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        float best_wt     = wt[0];
        int   best_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if (expert < n_experts && wt[i] > best_wt) {
                best_wt     = wt[i];
                best_expert = expert;
            }
        }

        // Butterfly argmax across the warp. Ties are resolved towards the lower expert index so
        // that every lane ends with the same winner. With a plain `>` each tied lane would keep
        // its own candidate: several lanes would then write weights[k]/ids[k] concurrently and
        // mark different entries as -inf, and padding lanes (expert >= n_experts) could later win.
#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val    = __shfl_xor_sync(0xFFFFFFFF, best_wt, mask, WARP_SIZE);
            const int   expert = __shfl_xor_sync(0xFFFFFFFF, best_expert, mask, WARP_SIZE);
            if (val > best_wt || (val == best_wt && expert < best_expert)) {
                best_wt     = val;
                best_expert = expert;
            }
        }

        // best_wt is identical on every lane here, so every lane accumulates the full sum.
        sum_selected += best_wt;
        if ((best_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[best_expert / WARP_SIZE] = -INFINITY;

            weights[k] = best_wt;
            ids[k]     = best_expert;
        }
    }

    if (!normalize) return;

    // weights[k] was stored by whichever lane owned the winning expert; make those stores visible
    // to the lanes that rescale them below. Only this warp touches this row, so a warp barrier is
    // sufficient (a block-wide __syncthreads() would not be reached by warps whose row exited early).
    __syncwarp();

    const float norm = 1.0f / sum_selected;
    for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
        weights[k] *= norm;
    }
}

// Routing kernel for the case where every expert is selected: each row of `weights` receives the
// full softmax of the corresponding row of `logits`, and each row of `ids` receives 0..n_experts-1
// in ascending order (not sorted by weight). One warp per row (threadIdx.y selects the row); lanes
// walk the experts with a stride of WARP_SIZE.
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
                                                                    float *       weights,
                                                                    int32_t *     ids,
                                                                    const int     n_rows,
                                                                    const int     n_experts) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row < n_rows) {
        const int     lane        = threadIdx.x;
        const float * row_logits  = logits + n_experts * row;
        float *       row_weights = weights + n_experts * row;
        int32_t *     row_ids     = ids + n_experts * row;

        // Pass 1: running maximum of the logits; also emit the identity expert ids.
        float row_max = -INFINITY;
        for (int e = lane; e < n_experts; e += WARP_SIZE) {
            row_max    = max(row_max, row_logits[e]);
            row_ids[e] = e;
        }
        row_max = warp_reduce_max(row_max);

        // Pass 2: shifted exponentials, stored in place and summed.
        float denom = 0.0f;
        for (int e = lane; e < n_experts; e += WARP_SIZE) {
            const float ex = expf(row_logits[e] - row_max);
            row_weights[e] = ex;
            denom += ex;
        }
        denom = warp_reduce_sum(denom);

        // Pass 3: scale so the row sums to 1.
        const float scale = 1.0f / denom;
        for (int e = lane; e < n_experts; e += WARP_SIZE) {
            row_weights[e] *= scale;
        }
    }
}

// Launches the fused MoE routing kernels on ctx's stream: one warp per row, rows_per_block rows per
// block (rows_per_block * WARP_SIZE must stay within the kernels' __launch_bounds__(4 * WARP_SIZE)).
//
// If every expert is selected (n_expert_used == n_expert), simple_moe_cuda is used. It writes the
// full softmax and ids 0..n_expert-1 in ascending order, and does not apply `normalize`; the full
// softmax already sums to 1 up to rounding. Otherwise topk_moe_cuda is instantiated for n_expert,
// which must be one of 1, 2, 4, ..., 512. Any other value aborts.
template <bool normalize>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used) {
    // Nothing to compute, and a zero-sized grid would be an invalid launch configuration.
    if (n_rows <= 0) {
        return;
    }

    constexpr int rows_per_block = 4;
    dim3          grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3          block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t  stream = ctx.stream();

    if (n_expert_used == n_expert) {
        simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
        return;
    }

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
            topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}

// Host entry point for the fused MoE routing op.
//   logits  : the F32 softmax node. The data actually read is its src[0] (the raw logits);
//             ne[0] gives the number of experts and ne[1] the number of rows.
//   weights : F32 output for the selected weights. When it is a GGML_OP_DIV node the selected
//             weights are normalized and n_expert_used is read from ne[0]; otherwise it is read
//             from ne[1] (presumably the 3-D get_rows output; confirm against the graph builder).
//   ids     : I32 output for the selected expert ids; its row stride must be n_experts elements.
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               ids) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

    const float * logits_d  = (const float *) logits->src[0]->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;

    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    const bool normalize     = weights->op == GGML_OP_DIV;
    const int  n_expert_used = normalize ? weights->ne[0] : weights->ne[1];

    // The kernels select at most n_experts distinct experts per row; asking for more would make the
    // top-k loop pick -inf / padding entries and emit invalid expert ids.
    GGML_ASSERT(n_expert_used <= n_experts);

    if (normalize) {
        launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    }
}

// Returns true when the softmax node `softmax` (with selected-weights tensor `weights`) can be
// handled by the fused topk-moe kernels. This requires:
//   - contiguous softmax input and contiguous `weights`,
//   - no mask (src[1]) and no sinks (src[2]) on the softmax,
//   - softmax op_params scale == 1 and max_bias == 0,
//   - an expert count (softmax->ne[0]) that is a power of two and at most 512.
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    const bool has_mask_or_sinks = softmax->src[1] != nullptr || softmax->src[2] != nullptr;
    if (has_mask_or_sinks) {
        return false;
    }

    // op_params[0] holds the scale, op_params[1] the max_bias; only the plain softmax is fused.
    float sm_params[2] = { 1.0f, 0.0f };
    memcpy(sm_params, softmax->op_params, sizeof(sm_params));
    const bool plain_softmax = sm_params[0] == 1.0f && sm_params[1] == 0.0f;
    if (!plain_softmax) {
        return false;
    }

    const int  n_expert     = softmax->ne[0];
    const bool power_of_two = (n_expert & (n_expert - 1)) == 0;
    return power_of_two && n_expert <= 512;
}
